# Start from an empty workspace and run from the results folder: the RDS
# output paths written later in this script are relative to this directory.
# NOTE(review): setwd() to a user-specific path makes the script non-portable.
rm(list = ls())
setwd(file.path("~", "Desktop", "Experiments", "Simulations", "Results"))
# Simulation: compare four estimators of the SBM edge-probability matrix
# Theta from a partially observed network (each dyad observed w.p. 0.5):
#   - softImpute rank-K matrix completion,
#   - missSBM's fitted expectation,
#   - block-mean plug-in using missSBM's estimated memberships,
#   - block-mean plug-in using the true memberships.
# For every model type and network size N, M replicates are run and the
# per-replicate mean squared errors are saved to
# "Fixed K/<model> SBM/N_<N>.RDS" (relative to the working directory).

# Draw a symmetric 0/1 matrix with a zero diagonal: each upper-triangle entry
# is an independent Bernoulli draw (probabilities `prob_upper`, taken in
# upper.tri() order, recycled if shorter), mirrored into the lower triangle.
sample_symmetric_binary <- function(N, prob_upper) {
  X <- matrix(0, N, N)
  X[upper.tri(X)] <- rbinom(n = N * (N - 1) / 2, size = 1, prob = prob_upper)
  X + t(X)
}

# Block-mean plug-in estimate of Theta: Q_hat[a, b] is the mean of the
# observed (non-NA) entries of A_obs between communities a and b under
# `labels` (0 when that block is empty or fully unobserved). Returns Q_hat
# expanded to an N x N matrix with a zero diagonal.
block_plugin_estimate <- function(A_obs, labels, K) {
  Q_hat <- matrix(0, K, K)
  for (a in seq_len(K)) {
    for (b in seq_len(K)) {
      block_mean <- mean(A_obs[labels == a, labels == b], na.rm = TRUE)
      # mean() of an empty / all-NA block is NaN; keep the 0 default then.
      if (!is.na(block_mean)) {
        Q_hat[a, b] <- block_mean
      }
    }
  }
  Theta_hat <- Q_hat[labels, labels, drop = FALSE]
  diag(Theta_hat) <- 0
  Theta_hat
}

# Mean squared error of an estimate over all N^2 entries. The estimate's
# diagonal is zeroed first: the true Theta has a zero diagonal (no
# self-loops) and the diagonal of A_obs is never observed, so every
# estimator is scored on the same footing.
mse_vs_truth <- function(Theta_hat, Theta) {
  diag(Theta_hat) <- 0
  sum((Theta_hat - Theta)^2) / length(Theta)
}

# Simulation parameters ----
K <- 3
M <- 100 # number of experiments per (model type, N)
N_list <- round(exp(seq(log(50), log(1000), length.out = 10))) # log-spaced

for (type in c(1, 2, 3)) {
  # Q: probabilities of connection between communities. It must be symmetric
  # because only the upper triangle of Theta is sampled and then mirrored.
  Q <- switch(type,
    # 1: assortative SBM
    matrix(c(0.5, 0.2, 0.2,
             0.2, 0.5, 0.2,
             0.2, 0.2, 0.5), nrow = 3),
    # 2: disassortative SBM
    matrix(c(0.2, 0.5, 0.5,
             0.5, 0.2, 0.5,
             0.5, 0.5, 0.2), nrow = 3),
    # 3: mixed model
    matrix(c(0.1, 0.5, 0.3,
             0.5, 0.2, 0.4,
             0.3, 0.4, 0.6), nrow = 3)
  )
  stopifnot(isSymmetric(Q))

  # alpha: community proportions (balanced for types 1-2, unbalanced for 3).
  alpha <- if (type %in% 1:2) rep(1 / K, K) else c(0.1, 0.3, 0.6)

  # The output folder depends only on the model type; create it up front so
  # saveRDS() does not fail on a fresh results directory.
  model_dir <- file.path(
    "Fixed K",
    switch(type, "Assortative SBM", "Disassortative SBM", "Mixed SBM")
  )
  dir.create(model_dir, recursive = TRUE, showWarnings = FALSE)

  # Experiments ----
  for (N in N_list) {
    set.seed(0) # same random stream for every (type, N) pair
    start_time <- Sys.time()

    Error_softImpute <- rep(NA_real_, M)
    Error_missSBM <- rep(NA_real_, M)
    Error_Var <- rep(NA_real_, M)
    Error_true_Z <- rep(NA_real_, M)

    for (m in seq_len(M)) {
      # Draw the network ----
      Z <- base::sample(1:K, size = N, replace = TRUE, prob = alpha) # labels
      Theta <- Q[Z, Z] # Theta[i, j] = Q[Z[i], Z[j]]
      diag(Theta) <- 0 # no self-loops
      A <- sample_symmetric_binary(N, Theta[upper.tri(Theta)]) # full graph
      Omega <- sample_symmetric_binary(N, 0.5) # 1 = dyad is observed

      # Observed adjacency: diagonal and unsampled dyads are missing.
      A_obs <- A
      diag(A_obs) <- NA
      A_obs[Omega == 0] <- NA

      # softImpute ----
      # Namespaced because no library(softImpute) call appears in this
      # script. u %*% (d * t(v)) equals u %*% diag(d) %*% t(v) for however
      # many singular values are returned.
      fit_si <- softImpute::softImpute(
        A_obs, rank.max = K, lambda = 0, maxit = 500
      )
      Theta_softImpute <- fit_si$u %*% (fit_si$d * t(fit_si$v))
      Theta_softImpute <- pmin(pmax(Theta_softImpute, 0), 1)
      Error_softImpute[m] <- mse_vs_truth(Theta_softImpute, Theta)

      # missSBM ----
      fit_missSBM <- missSBM::estimateMissSBM(
        adjacencyMatrix = A_obs,
        vBlocks = c(K),
        sampling = "dyad",
        control = list(trace = 0)
      )$bestModel$fittedSBM
      Error_missSBM[m] <- mse_vs_truth(fit_missSBM$expectation, Theta)

      # Block-mean plug-ins: missSBM's estimated labels vs the true labels.
      Theta_Var <- block_plugin_estimate(A_obs, fit_missSBM$memberships, K)
      Error_Var[m] <- mse_vs_truth(Theta_Var, Theta)

      Theta_true_Z <- block_plugin_estimate(A_obs, Z, K)
      Error_true_Z[m] <- mse_vs_truth(Theta_true_Z, Theta)
    }

    results <- list(
      N = N, K = K, fixed_K = TRUE, type = type,
      Error_softImpute = Error_softImpute, Error_missSBM = Error_missSBM,
      Error_Var = Error_Var, Error_true_Z = Error_true_Z
    )

    path <- file.path(model_dir, paste0("N_", N, ".RDS"))
    saveRDS(results, file = path)

    print(path)
    print(Sys.time() - start_time)
  }
}
